#!/usr/bin/env python
# Copyright (c) 2016, Konstantinos Kamnitsas
# All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the BSD license. See the accompanying LICENSE file
# or read the terms at https://opensource.org/licenses/BSD-3-Clause.

from __future__ import absolute_import, print_function, division
import sys
import os
import argparse
import traceback

sys.setrecursionlimit(20000)

from deepmedic.frontEnd.configParsing.utils import getAbsPathEvenIfRelativeIsGiven

from deepmedic.frontEnd.configParsing.modelConfig import ModelConfig
from deepmedic.frontEnd.configParsing.trainConfig import TrainConfig
from deepmedic.frontEnd.configParsing.testConfig import TestConfig

from deepmedic.frontEnd.trainSession import TrainSession
from deepmedic.frontEnd.testSession import TestSession

from deepmedic.frontEnd.configParsing.modelParams import ModelParameters

from tensorflow.python.client import device_lib


OPT_MODEL = "-model"
OPT_TRAIN = "-train"
OPT_TEST = "-test"
OPT_LOAD = "-load"

OPT_DEVICE = "-dev"
ARG_CPU_PROC = "cpu"
ARG_GPU_PROC = "cuda"
DEF_DEV_PROC = ARG_CPU_PROC

OPT_RESET = "-resetopt"


def str_is_int(s):
    """Tell whether the built-in ``int()`` can parse ``s``.

    Args:
        s: Value to test, normally a string.

    Returns:
        True if ``int(s)`` succeeds, False if it raises ``ValueError``.
        Other exceptions (e.g. ``TypeError`` for ``None``) propagate.
    """
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True
    
def setup_arg_parser() :
    """Build the command-line parser of DeepMedic.

    Returns:
        An ``argparse.ArgumentParser`` with the options OPT_MODEL, OPT_TRAIN, OPT_TEST,
        OPT_LOAD, OPT_DEVICE and OPT_RESET registered. The option dependencies described
        in the help texts (e.g. OPT_TRAIN / OPT_TEST requiring OPT_MODEL) are NOT enforced
        by this parser; argparse only parses the values.
    """
    parser = argparse.ArgumentParser( prog='DeepMedic', formatter_class=argparse.RawTextHelpFormatter,
    description="\nThis software allows creation and supervised training of 3D, multi-scale CNN models for segmentation of structures in biomedical NIFTI volumes.\n"+\
                "The project is hosted at: https://github.com/Kamnitsask/deepmedic \n"+\
                "See the documentation for details on its use.\n"+\
                "This software accompanies the research presented in:\n"+\
                "Kamnitsas et al, \"Efficient Multi-Scale 3D CNN with Fully Connected CRF for Accurate Brain Lesion Segmentation\", Medical Image Analysis, 2016.\n"+\
                "We hope our work aids you in your endeavours.\n"+\
                "For questions and feedback contact: konstantinos.kamnitsas12@ic.ac.uk")
    
    parser.add_argument(OPT_MODEL, dest='model_cfg', type=str, help="Specify the architecture of the model to be used, by providing a config file [MODEL_CFG].")
    # The metavar argparse shows for this option is TRAIN_CFG (dest upper-cased), so the help text uses the same name.
    parser.add_argument(OPT_TRAIN, dest='train_cfg', type=str, help="Train a model with training parameters given by specifying config file [TRAIN_CFG].\n"+\
                                                                    "This option must follow a ["+OPT_MODEL+" MODEL_CFG] option, so that architecture of the to-train model is specified.\n"+\
                                                                    "Additionally, an existing checkpoint of the model can be specified in the [TRAIN_CFG] file or by the additional option ["+OPT_LOAD+"], to continue training it.")
    # Testing REQUIRES the model option; the only forbidden combination is with the training option.
    parser.add_argument(OPT_TEST, dest='test_cfg', type=str, help="Test with an existing model. The testing session's parameters should be given in config file [TEST_CFG].\n"+\
                                                                    "This option must follow a ["+OPT_MODEL+" MODEL_CFG] option, so that architecture of the model is specified.\n"+\
                                                                    "Existing pretrained model can be specified in the given [TEST_CFG] file or by the additional option ["+OPT_LOAD+"].\n"+\
                                                                    "This option cannot be used in combination with ["+OPT_TRAIN+"].")
    parser.add_argument(OPT_LOAD, dest='saved_model', type=str, help="The path to a saved existing checkpoint with learnt weights of the model, to train or test with.\n"+\
                                                                    "This option must follow a ["+OPT_TRAIN+"] or ["+OPT_TEST+"] option.\n"+\
                                                                    "If given, this option will override any \"model\" parameters given in the [TRAIN_CFG] or [TEST_CFG] files.")
    parser.add_argument(OPT_DEVICE, default = DEF_DEV_PROC, dest='device', type=str,  help="Specify the device to run the process on. Values: [" + ARG_CPU_PROC + "] or [" + ARG_GPU_PROC + "] (default = " + DEF_DEV_PROC + ").\n"+\
                                                                    "In the case of multiple GPUs, specify a particular GPU device with a number, in the format: " + OPT_DEVICE + " " + ARG_GPU_PROC + "0 \n"+\
                                                                    "NOTE: For GPU processing, CUDA libraries must be first added in your environment's PATH and LD_LIBRARY_PATH. See accompanying documentation.")
    parser.add_argument(OPT_RESET, dest='reset_trainer', action='store_true', help="Use optionally with a ["+OPT_TRAIN+"] command. Does not take an argument.\n"+\
                                                                    "Usage: ./deepMedicRun " + OPT_MODEL + " /path/to/model/config "+OPT_TRAIN+" /path/to/train/config "+OPT_RESET+" ...etc...\n"+\
                                                                    "Resets the model\'s optimization state before starting the training session (eg number of epochs already trained, current learning rate etc).\n"+\
                                                                    "IMPORTANT: Trainable parameters are NOT reinitialized! \n"+\
                                                                    "Useful to begin a secondary training session with new learning-rate schedule, in order to fine-tune a previously trained model (Doc., Sec. 3.2)")
    
    return parser


def check_dev_passed_correctly(devArg) :
    """Validate the value given to the device option; print an error and exit(1) if invalid.

    Accepted values:
        * ARG_CPU_PROC  (run on CPU),
        * ARG_GPU_PROC  (use the CUDA devices, TF picks the first),
        * ARG_GPU_PROC immediately followed by a GPU index made only of the digits 0-9
          (e.g. "cuda0", "cuda12").

    A bare ``int()`` check is deliberately NOT used for the index. ``int()`` also accepts
    signs and surrounding whitespace ("cuda-1", "cuda 1"). Such a suffix would be copied
    verbatim into CUDA_VISIBLE_DEVICES (a negative value hides every GPU) while a GPU device
    is still requested.

    Args:
        devArg: The string given on the command line for the device option.
    """
    if devArg == ARG_CPU_PROC or devArg == ARG_GPU_PROC:
        return
    if devArg.startswith(ARG_GPU_PROC):
        gpu_index = devArg[ len(ARG_GPU_PROC) : ]
        if len(gpu_index) > 0 and all(ch in "0123456789" for ch in gpu_index):
            return
    
    print(  "ERROR: Value for the [" + OPT_DEVICE + "] option was not specified correctly. Specify the device to run the process on. \n"+\
            "\tValues: [" + ARG_CPU_PROC + "] or [" + ARG_GPU_PROC + "] (Default = " + DEF_DEV_PROC + ").\n"+\
            "\tIn the case of multiple GPUs, specify a particular GPU device with a number, in the format: " + ARG_GPU_PROC + "2. Exiting.")
    # sys.exit rather than the `exit` builtin, which only exists when the `site` module is loaded.
    sys.exit(1)
    
    
def set_environment(dev_string):
    """Restrict the CUDA devices visible to the process and choose the session device.

    Args:
        dev_string: Value of the device option (ARG_CPU_PROC, ARG_GPU_PROC, or
            ARG_GPU_PROC followed by an integer GPU index).

    Returns:
        "/CPU:0" for ARG_CPU_PROC, after hiding all GPUs via CUDA_VISIBLE_DEVICES="".
        "/device:GPU:0" for ARG_GPU_PROC + index. CUDA_VISIBLE_DEVICES is set to that index,
            so the chosen GPU is the only one visible and is therefore GPU:0.
        None otherwise (incl. plain ARG_GPU_PROC). None lets TF see all CUDA devices and
            assign to the first one.
    """
    if dev_string == ARG_CPU_PROC:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        return "/CPU:0"

    if dev_string.startswith(ARG_GPU_PROC):
        gpu_index = dev_string[len(ARG_GPU_PROC):]
        if str_is_int(gpu_index):
            os.environ["CUDA_VISIBLE_DEVICES"] = gpu_index
            return "/device:GPU:0"

    return None

#################################################
#                        MAIN                   #
#################################################
if __name__ == '__main__':
    # Script entry point:
    #   1. validate the command line;
    #   2. parse the model and session config files;
    #   3. create output folders and the logger;
    #   4. select the compute device;
    #   5. run the requested training or testing session.
    #
    # NOTE: sys.exit() is used instead of the `exit` builtin. The builtin is injected by the
    # `site` module and is not guaranteed to exist (e.g. under `python -S` or when frozen).
    cwd = os.getcwd()
    parser = setup_arg_parser()
    args = parser.parse_args()
    
    if len(sys.argv) == 1:
        print("For help on the usage of this program, please use the option -h.")
        sys.exit(1)
        
    if not args.model_cfg :
        print("ERROR: Option ["+OPT_MODEL+"] must be specified, pointing to a [MODEL_CFG] file that describes the architecture.\n"+\
              "Please try [-h] for more information. Exiting.")
        sys.exit(1)
    if not (args.train_cfg or args.test_cfg) :
        print("ERROR: One of the options must be specified:\n"+\
              "\t["+OPT_TRAIN+"] to start a training session on a model.\n"+\
              "\t["+OPT_TEST+"] to test with an existing model.\n"+\
              "Please try [-h] for more information. Exiting.")
        sys.exit(1)
        
    # Preliminary checks on option combinations that the argument parser does not enforce.
    if args.test_cfg and args.train_cfg:
        print("ERROR:\t["+OPT_TEST+"] cannot be used in conjunction with ["+OPT_TRAIN+"].\n"+\
              "\tTo test with an existing network, please just specify a configuration file for the testing process, which will include a path to a trained model, or specify a model with ["+OPT_LOAD+"]. Exiting.")
        sys.exit(1)
              
    if args.reset_trainer and not args.train_cfg :
        print("ERROR:\tThe option ["+OPT_RESET+"] can only be used together with the ["+OPT_TRAIN+"] option.\n\tPlease try -h for more information. Exiting.")
        sys.exit(1)
    
    # Parse the model config (guaranteed to be given by the checks above).
    # Relative paths are resolved against the current working directory.
    abs_path_model_cfg = getAbsPathEvenIfRelativeIsGiven(args.model_cfg, cwd)
    model_cfg = ModelConfig( abs_path_model_cfg )
        
    # Create the session. Exactly one of train_cfg / test_cfg is set at this point.
    if args.train_cfg:
        abs_path_train_cfg = getAbsPathEvenIfRelativeIsGiven(args.train_cfg, cwd)
        session = TrainSession( TrainConfig(abs_path_train_cfg) )
    elif args.test_cfg:
        abs_path_test_cfg = getAbsPathEvenIfRelativeIsGiven(args.test_cfg, cwd)
        session = TestSession( TestConfig(abs_path_test_cfg) )
        
    # Create output folders and logger.
    session.make_output_folders()
    session.setup_logger()
    
    log = session.get_logger()
    
    log.print3("")
    log.print3("======================== Starting new session ============================")
    log.print3("Command line arguments given: \n" + str(args) )
    
    # Validate the device string (exits on error), then set CUDA_VISIBLE_DEVICES before
    # TF enumerates the devices below.
    check_dev_passed_correctly(args.device)
    sess_device = set_environment(args.device)
    log.print3("Available devices to Tensorflow:\n" + str(device_lib.list_local_devices()))
    
    run_failed = False
    try:
        log.print3("CONFIG: The configuration file for the [model] given is: " + str( model_cfg.get_abs_path_to_cfg() ))
        model_params = ModelParameters( log, model_cfg )
        model_params.print_params()
            
        # Sessions
        log.print3("CONFIG: The configuration file for the [session] was loaded from: " + str(session.get_abs_path_to_cfg() ))
        session.override_file_cfg_with_cmd_line_cfg(args)
        _ = session.compile_session_params_from_cfg(model_params)
        
        if args.train_cfg:
            session.run_session(sess_device, model_params, args.reset_trainer)
        elif args.test_cfg:
            session.run_session(sess_device, model_params)
        # All done.
    except (Exception, KeyboardInterrupt) as e:
        # Record the failure (including Ctrl+C) in the session log before exiting.
        run_failed = True
        log.print3("")
        log.print3("ERROR: Caught exception from main process: " + str(e) )
        log.print3( traceback.format_exc() )
        
    log.print3("Finished.")
    if run_failed:
        # Report failure to the calling shell / job scheduler via a non-zero exit status.
        sys.exit(1)
    
        
